# Copyright 2019 kubeflow.org.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import datetime
from distutils.util import strtobool
import logging
import yaml
import launch_crd

from kubernetes import client as k8s_client
from kubernetes import config


def yamlOrJsonStr(str):
    """Convert a command-line string into a Python object via yaml.safe_load.

    Used as an argparse ``type=`` converter for the structured options.

    Args:
        str: Raw text of the argument. The name shadows the builtin but is
            kept so existing callers are unaffected.

    Returns:
        None when the text is empty or None, otherwise whatever
        ``yaml.safe_load`` produces for the text.
    """
    # Blank input means "option not provided": skip the parser entirely.
    if str in ("", None):
        return None
    return yaml.safe_load(str)


# Group name of the PaddleJob custom resource. Passed to the K8sCR base class
# and used as the prefix of the manifest's "apiVersion" field.
PaddleJobGroup = "batch.paddlepaddle.org"
# Plural resource name of the PaddleJob custom resource, passed to the K8sCR
# base class alongside the group.
PaddleJobPlural = "paddlejobs"


class PaddleJob(launch_crd.K8sCR):
    """Handle for PaddleJob custom resources, built on launch_crd.K8sCR.

    Binds the PaddleJob group and plural name to the generic base class and
    supplies the phase check for a PaddleJob instance.
    """

    def __init__(self, version="v1", client=None):
        """Initialise the base class for the PaddleJob resource.

        Args:
            version: Version of the PaddleJob resource, "v1" by default.
            client: Kubernetes API client object forwarded to the base class.
        """
        super(PaddleJob, self).__init__(
            PaddleJobGroup, PaddleJobPlural, version, client)

    def is_expected_conditions(self, inst, expected_conditions):
        """Report whether the job's current phase is one of the expected ones.

        Args:
            inst: Dict form of a PaddleJob object; only ``status.phase`` is
                read.
            expected_conditions: Collection of phase names to match against.

        Returns:
            A ``(matched, phase)`` tuple. ``phase`` is the empty string when
            the object has no phase yet, in which case ``matched`` is False.
        """
        phase = inst.get('status', {}).get("phase")
        # No status/phase reported yet: nothing to match against.
        if not phase:
            return False, ""
        return phase in expected_conditions, phase


def main(argv=None):
    """Create a PaddleJob from command-line options and wait for it to end.

    Builds the PaddleJob manifest from the parsed options, submits it through
    the PaddleJob helper, blocks until the job reports one of the terminal
    phases (or the timeout elapses), and optionally deletes the job.

    Args:
        argv: Optional list of argument strings to parse. When None, argparse
            falls back to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='PaddleJob launcher')
    parser.add_argument('--name', type=str,
                        help='PaddleJob name.')
    parser.add_argument('--namespace', type=str,
                        default='kubeflow',
                        help='PaddleJob namespace.')
    parser.add_argument('--version', type=str,
                        default='v1',
                        help='PaddleJob version.')
    parser.add_argument('--timeoutMinutes', type=int,
                        default=60*24,
                        help='Time in minutes to wait for the PaddleJob to reach end')
    parser.add_argument('--deleteAfterDone', type=strtobool,
                        default=False,
                        help='delete PaddleJob after the job is done')

    parser.add_argument('--cleanPodPolicy', type=str,
                        default="OnCompletion",
                        help='defines whether to clean pod after job finished')
    parser.add_argument('--schedulingPolicy', type=yamlOrJsonStr,
                        default={},
                        help='defines the policy related to scheduling, for volcano')
    parser.add_argument('--intranet', type=str,
                        default="PodIP",
                        help='defines the communication mode inter pods : PodIP, Service or Host')
    parser.add_argument('--withGloo', type=int,
                        default=1,
                        help='indicate whether enable gloo, 0/1/2 for disable/enable for worker/enable for server')
    parser.add_argument('--sampleSetRef', type=yamlOrJsonStr,
                        default={},
                        help='defines the sample data set used for training and its mount path in worker pods')
    parser.add_argument('--ps', type=yamlOrJsonStr,
                        default={},
                        help='describes the spec of server base on pod template')
    parser.add_argument('--worker', type=yamlOrJsonStr,
                        default={},
                        help='describes the spec of worker base on pod template')
    parser.add_argument('--heter', type=yamlOrJsonStr,
                        default={},
                        help='describes the spec of heter worker base on pod temlate')
    parser.add_argument('--elastic', type=int,
                        default=0,
                        help='indicate the elastic level')
    # Honour the caller-supplied argument list; None keeps the sys.argv default.
    args = parser.parse_args(argv)

    logging.getLogger().setLevel(logging.INFO)
    logging.info('Generating PaddleJob template.')

    config.load_incluster_config()
    api_client = k8s_client.ApiClient()
    pdj = PaddleJob(version=args.version, client=api_client)

    # Fields that are always present in the spec.
    spec = {
        "cleanPodPolicy": args.cleanPodPolicy,
        "withGloo": args.withGloo,
        "intranet": args.intranet,
    }
    # Structured options are only emitted when the user supplied a non-empty
    # value (an empty string parses to None, the default is an empty dict).
    for field in ("schedulingPolicy", "sampleSetRef", "ps", "worker", "heter"):
        value = getattr(args, field)
        if value:
            spec[field] = value
    if args.elastic > 0:
        spec["elastic"] = args.elastic

    inst = {
        "apiVersion": "%s/%s" % (PaddleJobGroup, args.version),
        "kind": "PaddleJob",
        "metadata": {
            "name": args.name,
            "namespace": args.namespace,
        },
        "spec": spec,
    }

    create_response = pdj.create(inst)
    print("create PaddleJob have response {}".format(create_response))

    # Phases that end the wait. NOTE(review): "Failed" and "Terminated" end
    # the wait as well, and this function does not inspect the outcome
    # afterwards -- confirm whether a failed job should change the exit status.
    expected_conditions = ["Succeed", "Completed", "Failed", "Terminated"]
    pdj.wait_for_condition(
        args.namespace, args.name, expected_conditions,
        timeout=datetime.timedelta(minutes=args.timeoutMinutes))
    if args.deleteAfterDone:
        pdj.delete(args.name, args.namespace)


# Run the launcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
